CS236605: Deep Learning on Computational Accelerators

Homework Assignment 3

Faculty of Computer Science, Technion.

Submitted by:

# Name Id email
Student 1 [your name here] [your id here] [your email here]
Student 2 [your name here] [your id here] [your email here]

Introduction

In this assignment we'll learn to generate text with a deep multilayer RNN based on GRU cells. Then we'll turn our attention to image generation and implement two different generative models: a variational autoencoder and a generative adversarial network.

General Guidelines

  • Please read the getting started page on the course website. It explains how to set up, run and submit the assignment.
  • This assignment requires running on GPU-enabled hardware. Please read the course servers usage guide. It explains how to use and run your code on the course servers to benefit from training with GPUs.
  • The text and code cells in these notebooks are intended to guide you through the assignment and help you verify your solutions. The notebooks do not need to be edited at all (unless you wish to play around). The only exception is to fill your name(s) in the above cell before submission. Please do not remove sections or change the order of any cells.
  • All your code (and even answers to questions) should be written in the files within the python package corresponding to the assignment number (hw1, hw2, etc). You can of course use any editor or IDE to work on these files.
$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bb}[1]{\boldsymbol{#1}} $$

Part 1: Sequence Models

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!

In [1]:
import unittest
import os
import sys
import pathlib
import urllib.request  # the submodule must be imported explicitly for urllib.request.urlopen
import shutil
import re

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Text generation with a char-level RNN

Obtaining the corpus

Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.

In [2]:
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')

def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))
    
    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename
    
corpus_path = download_corpus()
Corpus file /home/thaer/.pytorch-datasets/shakespeare.txt exists, skipping download.

Load the text into memory and print a snippet:

In [3]:
with open(corpus_path, 'r') as f:
    corpus = f.read()

print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL

by William Shakespeare

Dramatis Personae

  KING OF FRANCE
  THE DUKE OF FLORENCE
  BERTRAM, Count of Rousillon
  LAFEU, an old lord
  PAROLLES, a follower of Bertram
  TWO FRENCH LORDS, serving with Bertram

  STEWARD, Servant to the Countess of Rousillon
  LAVACHE, a clown and Servant to the Countess of Rousillon
  A PAGE, Servant to the Countess of Rousillon

  COUNTESS OF ROUSILLON, mother to Bertram
  HELENA, a gentlewoman protected by the Countess
  A WIDOW OF FLORENCE.
  DIANA, daughter to the Widow

  VIOLENTA, neighbour and friend to the Widow
  MARIANA, neighbour and friend to the Widow

  Lords, Officers, Soldiers, etc., French and Florentine  

SCENE:
Rousillon; Paris; Florence; Marseilles

ACT I. SCENE 1.
Rousillon. The COUNT'S palace

Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black

  COUNTESS. In delivering my son from me, I bury a second husband.
  BERTRAM. And I in going, madam, weep o'er my father's death anew;
    but I must attend his Majesty's command, to whom I am now in
    ward, evermore in subjection.
  LAFEU. You shall find of the King a husband, madam; you, sir, a
    father. He that so generally is at all times good must of
    

Data Preprocessing

The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.
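A minimal sketch of such a mapping in plain Python, assuming indices are assigned in sorted character order (which is what the printed mapping below suggests). `build_char_maps` is a hypothetical name used for illustration, not the `char_maps()` function the assignment asks for.

```python
def build_char_maps(text: str):
    """Hypothetical helper: map each unique char to an index (sorted order) and back."""
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for c, i in char_to_idx.items()}
    return char_to_idx, idx_to_char


char_to_idx, idx_to_char = build_char_maps("hello")
print(char_to_idx)  # {'e': 0, 'h': 1, 'l': 2, 'o': 3}
print(idx_to_char)  # {0: 'e', 1: 'h', 2: 'l', 3: 'o'}
```

Building `idx_to_char` from `char_to_idx` keeps the two dicts in the same order, which is what the key/value ordering checks in the next cell rely on.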

TODO: Implement the char_maps() function in the hw3/charnn.py module.

In [4]:
import hw3.charnn as charnn

char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)

test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'\n': 0, ' ': 1, '!': 2, '"': 3, '$': 4, '&': 5, "'": 6, '(': 7, ')': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14, '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21, ':': 22, ';': 23, '<': 24, '?': 25, 'A': 26, 'B': 27, 'C': 28, 'D': 29, 'E': 30, 'F': 31, 'G': 32, 'H': 33, 'I': 34, 'J': 35, 'K': 36, 'L': 37, 'M': 38, 'N': 39, 'O': 40, 'P': 41, 'Q': 42, 'R': 43, 'S': 44, 'T': 45, 'U': 46, 'V': 47, 'W': 48, 'X': 49, 'Y': 50, 'Z': 51, '[': 52, ']': 53, '_': 54, 'a': 55, 'b': 56, 'c': 57, 'd': 58, 'e': 59, 'f': 60, 'g': 61, 'h': 62, 'i': 63, 'j': 64, 'k': 65, 'l': 66, 'm': 67, 'n': 68, 'o': 69, 'p': 70, 'q': 71, 'r': 72, 's': 73, 't': 74, 'u': 75, 'v': 76, 'w': 77, 'x': 78, 'y': 79, 'z': 80, '}': 81, '\ufeff': 82}

It seems the corpus contains some strange characters that are very rare and probably due to mistakes. Since each char will later be represented by a tensor whose length is the number of unique chars, it's best to remove them and keep those tensors shorter.
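One way to do the removal is a single pass with `str.translate`, mapping each unwanted char to `None` (delete). This is a sketch with a hypothetical name, `strip_chars`, not the required `remove_chars()`; it returns the count of deleted chars, as the next cell expects.

```python
def strip_chars(text: str, chars_to_remove):
    """Hypothetical helper: delete every occurrence of the given chars.

    Returns the cleaned text and the number of chars removed.
    """
    table = {ord(c): None for c in chars_to_remove}  # map each char to "delete"
    cleaned = text.translate(table)
    return cleaned, len(text) - len(cleaned)


cleaned, n_removed = strip_chars("\ufeffa$b_c}", ['}', '$', '_', '<', '\ufeff'])
print(repr(cleaned), n_removed)  # 'abc' 4
```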

TODO: Implement the remove_chars() function in the hw3/charnn.py module.

In [5]:
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')

# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars

The next thing we need is an embedding of the characters. An embedding is a representation of each token from the sequence as a tensor. For a char-level RNN, our tokens will be chars and we can thus use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented as a tensor whose length is the total number of unique chars (V) and which contains all zeros except at the index corresponding to that specific char.
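A sketch of one-hot encoding and decoding with `torch.nn.functional.one_hot`, under the assumption of a plain dict vocabulary; `onehot_encode`/`onehot_decode` are hypothetical names, not the assignment's `chars_to_onehot()`/`onehot_to_chars()`.

```python
import torch
import torch.nn.functional as F


def onehot_encode(text, char_to_idx):
    """Hypothetical helper: (len(text), V) int8 tensor with a single 1 per row."""
    idx = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    return F.one_hot(idx, num_classes=len(char_to_idx)).to(torch.int8)


def onehot_decode(onehot, idx_to_char):
    """Inverse mapping: the position of the 1 in each row gives the char index."""
    idx = onehot.to(torch.long).argmax(dim=1)
    return "".join(idx_to_char[i] for i in idx.tolist())


char_to_idx = {'a': 0, 'b': 1, 'c': 2}
idx_to_char = {i: c for c, i in char_to_idx.items()}
emb = onehot_encode("cab", char_to_idx)
print(emb.tolist())                     # [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
print(onehot_decode(emb, idx_to_char))  # cab
```

Storing the embedding as `int8` matters at corpus scale: with the shapes printed further down (99182 samples × 64 chars × 78 symbols, about 495 million entries) that is roughly 495 MB as `int8` versus about 1.98 GB as `float32`. The cast to float can be deferred to the forward pass, as the model test cell does with `x0.to(dtype=torch.float)`.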

TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.

In [6]:
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)

text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))

test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
   
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int8)

Dataset Creation

We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.

We will split our corpus into shorter sequences of length S chars (try to think why; see question below). Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence. For each sample, we'll also need a label. This is simply another sequence, shifted by one char so that the label of each char is the next char in the corpus.
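The shift-by-one labelling can be sketched with two reshaped views of one index tensor. This hypothetical `labelled_chunks` returns char indices for both samples and labels; the notebook's samples are one-hot `(S, V)` tensors instead, which could be obtained by applying `torch.nn.functional.one_hot` to `samples` afterwards.

```python
import torch


def labelled_chunks(text, char_to_idx, seq_len):
    """Hypothetical helper: split text into (N, seq_len) index tensors of inputs and labels.

    The label sequence is the input sequence shifted one char ahead.
    """
    idx = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    # -1 because the label of the very last input char must still exist in the text
    num_samples = (len(text) - 1) // seq_len
    end = num_samples * seq_len
    samples = idx[:end].view(num_samples, seq_len)
    labels = idx[1:end + 1].view(num_samples, seq_len)
    return samples, labels


char_to_idx = {c: i for i, c in enumerate("abcdefg")}
samples, labels = labelled_chunks("abcdefg", char_to_idx, seq_len=3)
print(samples.tolist())  # [[0, 1, 2], [3, 4, 5]]
print(labels.tolist())   # [[1, 2, 3], [4, 5, 6]]
```

The `(len(text) - 1) // seq_len` count is the same formula the test cell below uses for `num_samples`.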

TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.

In [7]:
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)

# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')

# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))

# Test content
for _ in range(1000):
    # random sample
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
    
print(f'sample 100 as text:\n{unembed(samples[100])}')
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
sample 100 as text:
nity, though valiant in the
    defence, yet is weak. Unfold to 

As usual, instead of feeding one sample at a time into our model's forward pass, we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively, this allows us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector multiplications during the forward pass.
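A shape-only sketch of that point with random tensors. The dimensions mirror the notebook's (B=32, S=64, V=78, H=256), and `W` stands in for any one of a layer's input weight matrices; the values themselves are arbitrary.

```python
import torch

torch.manual_seed(0)
B, S, V, H = 32, 64, 78, 256   # batch size, sequence length, vocab size, hidden dim
x = torch.rand(B, S, V)        # a batch of input sequences, batch dimension first
W = torch.randn(H, V)          # e.g. one input-to-hidden weight matrix of a GRU layer

t = 0
x_t = x[:, t, :]               # the chars at timestep t of all B sequences: (B, V)
one = x_t[0] @ W.T             # a single sample: matrix-vector product -> (H,)
batched = x_t @ W.T            # the whole batch: one matrix-matrix product -> (B, H)
print(batched.shape)           # torch.Size([32, 256])
```

Row 0 of `batched` equals `one` up to floating-point rounding, so batching changes how the work is scheduled, not the result.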

Let's use the standard PyTorch Dataset/DataLoader combo. Luckily, for the dataset we can use a built-in class, TensorDataset, which returns (sample, label) tuples from the samples and labels tensors we created above.

In [8]:
import torch.utils.data

# Create DataLoader returning batches of samples.
batch_size = 32

ds_corpus = torch.utils.data.TensorDataset(samples, labels)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, shuffle=False)

Let's see what that gives us:

In [9]:
print(f'num batches: {len(dl_corpus)}')

x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch sample: {x0.shape}')
print(f'shape of a batch label: {y0.shape}')
num batches: 3100
shape of a batch sample: torch.Size([32, 64, 78])
shape of a batch label: torch.Size([32, 64])

Model Implementation

Finally, our data set is ready so we can focus on our model.

What we'll implement here is a multilayer gated recurrent unit (GRU) model with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but it's somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.

The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.

Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as

$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$

The input to each layer is, $$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix} = \begin{cases} \mat{X} & \mathrm{if} ~k = 1~ \\ \mathrm{dropout}_p \left( \begin{bmatrix} {\vec{h}_1}^{[k-1]} \\ \vdots \\ {\vec{h}_S}^{[k-1]} \end{bmatrix} \right) & \mathrm{if} ~1 < k \leq L+1~ \end{cases}. $$

The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$

and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$
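The equations above can be sketched as a small PyTorch module. This is a hypothetical, deliberately unoptimized illustration (`TinyMultilayerGRU` is not the assignment's `MultilayerGRU`): it loops over layers and timesteps explicitly, puts the biases in the hidden-to-hidden linears like the model printout further down, assumes an all-zeros initial state when none is given, and applies dropout to every layer's output sequence, including the last one, since $\mat{X}^{[L+1]}$ is defined with dropout.

```python
import torch
import torch.nn as nn


class TinyMultilayerGRU(nn.Module):
    """Hypothetical, unoptimized sketch of the equations above (batch dimension first)."""

    def __init__(self, in_dim, h_dim, out_dim, n_layers, dropout=0.0):
        super().__init__()
        self.h_dim, self.n_layers = h_dim, n_layers
        self.layers = nn.ModuleList()
        for k in range(n_layers):
            d = in_dim if k == 0 else h_dim  # later layers read the previous layer's states
            self.layers.append(nn.ModuleDict({
                # W_x* terms carry no bias; the b_* biases live in the W_h* linears
                'xz': nn.Linear(d, h_dim, bias=False), 'hz': nn.Linear(h_dim, h_dim),
                'xr': nn.Linear(d, h_dim, bias=False), 'hr': nn.Linear(h_dim, h_dim),
                'xg': nn.Linear(d, h_dim, bias=False), 'hg': nn.Linear(h_dim, h_dim),
            }))
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(h_dim, out_dim)  # W_hy and B_y

    def forward(self, x, h0=None):
        # x: (B, S, in_dim); h0: (B, L, h_dim), or None for an all-zeros initial state
        B, S, _ = x.shape
        if h0 is None:
            h0 = x.new_zeros(B, self.n_layers, self.h_dim)
        layer_in, final_h = x, []
        for k, p in enumerate(self.layers):
            h = h0[:, k, :]
            states = []
            for t in range(S):
                x_t = layer_in[:, t, :]
                z = torch.sigmoid(p['xz'](x_t) + p['hz'](h))    # update gate
                r = torch.sigmoid(p['xr'](x_t) + p['hr'](h))    # reset gate
                g = torch.tanh(p['xg'](x_t) + p['hg'](r * h))   # candidate state
                h = z * h + (1 - z) * g
                states.append(h)
            final_h.append(h)                                   # h_S of layer k
            layer_in = self.dropout(torch.stack(states, dim=1)) # X^{[k+1]}: (B, S, h_dim)
        # Y: (B, S, out_dim) and H: (B, L, h_dim)
        return self.out(layer_in), torch.stack(final_h, dim=1)


model = TinyMultilayerGRU(in_dim=78, h_dim=256, out_dim=78, n_layers=2, dropout=0.5)
y, h = model(torch.rand(32, 64, 78))
print(y.shape, h.shape)  # torch.Size([32, 64, 78]) torch.Size([32, 2, 256])
```

Each layer holds six weight matrices and three bias vectors (9 parameter tensors), plus 2 for the output layer, the same $9L+2$ count the test cell below checks. Because the loop over $t$ is sequential in Python, this is far slower than `torch.nn.GRU`, which runs the recurrence in optimized kernels; the explicit loop is only meant to mirror the math line by line.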

Notes:

  • $t\in[1,S]$ is the timestep, i.e. the current position within the sequence of each sample.
  • $\vec{x}_t^{[k]}$ is the input of layer $k$ at timestep $t$.
  • The model outputs $\vec{y}_t$ (the rows of $\mat{Y}$, computed from the last layer's hidden states) are scores for the next character at every input char. These are similar to class scores in classification tasks.
  • The hidden states at the last timestep, $\vec{h}_S^{[k]}$, are the final hidden state returned from the model.
  • $\sigma(\cdot)$ is the sigmoid function, i.e. $\sigma(\vec{z}) = 1/(1+e^{-\vec{z}})$ which returns values in $(0,1)$.
  • $\tanh(\cdot)$ is the hyperbolic tangent, i.e. $\tanh(\vec{z}) = (e^{2\vec{z}}-1)/(e^{2\vec{z}}+1)$ which returns values in $(-1,1)$.
  • $\vec{h_t}^{[k]}$ is the hidden state of layer $k$ at time $t$. This can be thought of as the memory of that layer.
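  • $\vec{g_t}^{[k]}$ is the candidate hidden state: a proposal for the new state $\vec{h_t}^{[k]}$, which the update gate mixes with the previous state $\vec{h}_{t-1}^{[k]}$.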
  • $\vec{z_t}^{[k]}$ is known as the update gate. It combines the previous state with the input to determine how much the current state will be combined with the new candidate state. For example, if $\vec{z_t}^{[k]}=\vec{1}$ then the current input has no effect on the output.
  • $\vec{r_t}^{[k]}$ is known as the reset gate. It combines the previous state with the input to determine how much of the previous state will affect the current state candidate. For example if $\vec{r_t}^{[k]}=\vec{0}$ the previous state has no effect on the current candidate state.
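The two extreme cases in the notes can be checked on a single scalar unit. This is a toy sketch: `gru_scalar_step` is hypothetical, the gate values are passed in directly instead of being computed by sigmoids, and the candidate-state weights default to 1 as an arbitrary assumption.

```python
import math


def gru_scalar_step(x, h_prev, z, r, w_xg=1.0, w_hg=1.0, b_g=0.0):
    """Hypothetical scalar GRU update with the gate values z and r supplied directly."""
    g = math.tanh(w_xg * x + w_hg * (r * h_prev) + b_g)  # candidate state
    return z * h_prev + (1 - z) * g                      # interpolate old state and candidate


# z = 1: the new state equals the previous one, whatever the input is
print(gru_scalar_step(x=5.0, h_prev=0.3, z=1.0, r=0.5))   # 0.3
print(gru_scalar_step(x=-2.0, h_prev=0.3, z=1.0, r=0.5))  # 0.3

# z = 0, r = 0: the new state depends only on the input, not on h_prev
print(gru_scalar_step(x=0.7, h_prev=0.9, z=0.0, r=0.0))
print(gru_scalar_step(x=0.7, h_prev=-0.4, z=0.0, r=0.0))  # same value as the line above
```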

Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or use any mixture of the two (since the gates are continuous, with values in $(0,1)$).

Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.

TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.

Notes:

  • You'll need to handle input batches now. The math is identical to the above, but all the tensors will have an extra batch dimension as their first dimension.
  • Use the diagram above to help guide your implementation. It will help you visualize what shapes to return where, etc.
In [10]:
in_dim = vocab_len
h_dim = 256
n_layers = 2
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)

# Test forward pass
y, h = model(x0.to(dtype=torch.float))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')

test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2) 
MultilayerGRU(
  (layer 0 xz): Linear(in_features=78, out_features=256, bias=False)
  (layer 0 hz): Linear(in_features=256, out_features=256, bias=True)
  (layer 0 xr): Linear(in_features=78, out_features=256, bias=False)
  (layer 0 hr): Linear(in_features=256, out_features=256, bias=True)
  (layer 0 xg): Linear(in_features=78, out_features=256, bias=False)
  (layer 0 hg): Linear(in_features=256, out_features=256, bias=True)
  (layer 1 xz): Linear(in_features=256, out_features=256, bias=False)
  (layer 1 hz): Linear(in_features=256, out_features=256, bias=True)
  (layer 1 xr): Linear(in_features=256, out_features=256, bias=False)
  (layer 1 hr): Linear(in_features=256, out_features=256, bias=True)
  (layer 1 xg): Linear(in_features=256, out_features=256, bias=False)
  (layer 1 hg): Linear(in_features=256, out_features=256, bias=True)
  (output_layer): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 2, 256])

Generating text by sampling

Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability distribution over the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t; \vec{h}_t).$$

Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.
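A sketch of this sample-and-feed-back loop, assuming a model called as `model(x, h)` that returns `(y, h)`, takes `x` of shape `(1, S, V)`, and treats `h=None` as a zero initial state. That interface, the function name `sample_text`, and running on CPU are all assumptions for illustration. A model with dropout should be put in `eval()` mode before sampling.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def sample_text(model, start_seq, n_chars, char_to_idx, idx_to_char, T=1.0):
    """Hypothetical sketch: autoregressive sampling with hidden-state propagation.

    Returns a string of total length n_chars, including start_seq.
    """
    V = len(char_to_idx)

    def encode(s):
        idx = torch.tensor([char_to_idx[c] for c in s], dtype=torch.long)
        return F.one_hot(idx, V).float().unsqueeze(0)  # (1, len(s), V)

    out, x, h = start_seq, encode(start_seq), None
    while len(out) < n_chars:
        y, h = model(x, h)                             # h now summarizes all chars seen so far
        probs = torch.softmax(y[0, -1] / T, dim=-1)    # distribution for the next char
        next_char = idx_to_char[torch.multinomial(probs, 1).item()]
        out += next_char
        x = encode(next_char)                          # feed only the new char; history lives in h
    return out
```

The whole start sequence is fed once to warm up the hidden state; after that only the newly sampled char goes in. Dropping `h` between calls would make every prediction depend on a single char, which is why propagating it matters. Counting `start_seq` in the total length matches the test below, where `"foobar"` plus generated chars must have length 50.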

The important point however is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (more uniform) distributions if the score values are very similar. When sampling, we would prefer to control the distributions and make them less uniform to increase the chance of sampling the char(s) with the highest scores compared to the others.

To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$

A low $T$ will result in less uniform distributions and vice-versa.
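The formula translates almost directly into PyTorch. This hypothetical `softmax_with_temperature` is a sketch, not the required `hot_softmax()`; it relies on `torch.softmax`, which already subtracts the maximum internally for numerical stability.

```python
import torch


def softmax_with_temperature(y, temperature=1.0, dim=-1):
    """Hypothetical helper: softmax of the scores divided by the temperature T."""
    return torch.softmax(y / temperature, dim=dim)


scores = torch.tensor([1.0, 2.0, 3.0])
for T in (0.1, 1.0, 100.0):
    print(T, softmax_with_temperature(scores, T))
```

For these scores, $T=1$ gives roughly 0.09, 0.24 and 0.67; $T=0.1$ puts more than 0.99 of the mass on the highest score; and $T=100$ gives values within about 0.004 of the uniform 1/3.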

TODO: Implement the hot_softmax() function in the hw3/charnn.py module.

In [11]:
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))

for t in reversed([0.3, 0.5, 1.0, 100]):
    ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()

uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))

TODO: Implement the generate_from_model() function in the hw3/charnn.py module.

In [12]:
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobarVkGzwZSkZc!DtFfK;EsLSLJLTD)wsf
BiuvRzRcQB&t9
foobarJ6SoLX4JK)cVZW5xpOzN'CB2-A6,
pa28m,Xqne;ki0q
foobar0&
]t3wIl)uD(c8-Gw!AuKoe[k.Hj.ZGNVizrTF!W!r5

Training

To train such a model, we'll calculate the loss at each time step by comparing the predicted char to the actual char from our label. We can use cross entropy since per char it's similar to a classification problem. We'll then sum the losses over the sequence and back-propagate the gradients through time. Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times, so we'll accumulate gradients in parameters of the blocks. Luckily autograd will handle this part for us.
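Computing the per-timestep cross entropy for a batch amounts to flattening the batch and time dimensions so that every timestep becomes one classification sample. A sketch with a hypothetical helper name; note that `nn.CrossEntropyLoss` averages by default (`reduction='mean'`), which differs from summing over timesteps only by a constant factor (pass `reduction='sum'` to sum instead).

```python
import torch
import torch.nn as nn


def sequence_loss_and_accuracy(y, labels, loss_fn=None):
    """Hypothetical helper. y: (B, S, V) scores, labels: (B, S) char indices."""
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    B, S, V = y.shape
    # Flatten batch and time so each timestep is one "classification" sample
    loss = loss_fn(y.reshape(B * S, V), labels.reshape(B * S))
    # Percentage of timesteps where the highest-scoring char is the correct one
    accuracy = (y.argmax(dim=-1) == labels).float().mean().item() * 100
    return loss, accuracy
```

Calling `loss.backward()` on the result then propagates through every timestep of the unrolled recurrence, accumulating gradients into the shared per-layer parameters.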

As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.

For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is to get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.

Let's create a tiny dataset to memorize.

In [13]:
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size=1, shuffle=False)

# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":

TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

Now let's implement the first part of our training code.

TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module. Note: Think about how to correctly handle the hidden state of the model between batches and epochs (for this specific task, i.e. text generation).

In [14]:
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer

torch.manual_seed(42)

lr = 0.01
num_epochs = 500

in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
    
    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx,idx_to_char), T=0.1)
        # Stop if we've successfully memorized the small dataset.
        print(generated_sequence)
        if generated_sequence == subset_text:
            break

# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.942, Accuracy = 17.58%
Too                                                                                                                                                                                                                                                             

Epoch #25: Avg. loss = 0.184, Accuracy = 99.61%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

OK, so training works - we can memorize a short sequence. Next on the agenda is to split our full dataset into training and test sets of batched sequences.

In [15]:
# Full dataset definition
vocab_len = len(char_to_idx)
seq_len = 64
batch_size = 256
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)

samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)

ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=False, drop_last=True)

ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=True)

print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test:  {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
Train: 348 batches, 5701632 chars
Test:   38 batches,  622592 chars
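The printed batch and char counts follow directly from the sizes in the cell. This plain-Python recomputation assumes the corpus length after removing the 34 rare chars reported earlier, i.e. 6347703 - 34 chars.

```python
# Recompute the split sizes printed above (plain arithmetic, no data needed)
corpus_len = 6347703 - 34          # original length minus the 34 removed chars
seq_len, batch_size, train_test_ratio = 64, 256, 0.9

num_samples = (corpus_len - 1) // seq_len        # 99182 full sequences
num_train = int(train_test_ratio * num_samples)  # 89263
num_test = num_samples - num_train               # 9919

# drop_last=True discards the final incomplete batch of each split
train_batches = num_train // batch_size          # 348
test_batches = num_test // batch_size            # 38
print(train_batches, train_batches * batch_size * seq_len)  # 348 5701632
print(test_batches, test_batches * batch_size * seq_len)    # 38 622592
```

So `drop_last=True` leaves out 175 training samples and 191 test samples, a negligible fraction of the corpus.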

We'll now train a much larger model on our large dataset. You'll need a GPU for this part.

The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left off.

Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.

In [16]:
# Full training definition
lr = 0.001
num_epochs = 50

in_dim = out_dim = vocab_len
hidden_dim = 512
n_layers = 3
dropout = 0.5
checkpoint_file = 'checkpoints/rnn'
max_batches = 300
early_stopping = 5

model = charnn.MultilayerGRU(in_dim, hidden_dim, out_dim, n_layers, dropout)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

TODO:

  • Implement the fit() method of the Trainer class. You can reuse the implementation from HW2, but make sure to implement early stopping and checkpoints.
  • Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
  • Run the following block to train.
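The early-stopping bookkeeping that `fit()` needs can be sketched as a small counter. `EarlyStopper` is a hypothetical name, and treating a higher test accuracy as "improvement" is an assumption chosen to match the scheduler's `mode='max'`. When `improved` is `True`, a natural place to write a checkpoint is something like `torch.save({'model_state': model.state_dict()}, path)`, which uses the same `'model_state'` key the loading code in the next cell reads.

```python
class EarlyStopper:
    """Hypothetical helper: signal a stop after `patience` epochs without improvement.

    Higher metric is better (e.g. test accuracy).
    """

    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.epochs_without_improvement = 0

    def step(self, metric):
        """Record one epoch's metric; return (improved, should_stop)."""
        if self.best is None or metric > self.best:
            self.best = metric
            self.epochs_without_improvement = 0
            return True, False  # a natural point to save a checkpoint
        self.epochs_without_improvement += 1
        return False, self.epochs_without_improvement >= self.patience


stopper = EarlyStopper(patience=2)
for acc in [40.0, 45.0, 44.0, 45.0]:
    print(acc, stopper.step(acc))
# 40.0 (True, False)
# 45.0 (True, False)
# 44.0 (False, False)
# 45.0 (False, True)
```

Matching the best value (the final 45.0) does not count as an improvement here, so two epochs without a strictly better accuracy trigger the stop.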
In [17]:
from cs236605.plot import plot_fit

def post_epoch_fn(epoch, test_res, train_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
        )
        print(generated_sequence)

# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sampling
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))

        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=max_batches,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
*** Loading final checkpoint file checkpoints/rnn_final.pt instead of training

Generating a work of art

Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.

TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.

In [18]:
import hw3.answers

start_seq, temperature = hw3.answers.part1_generation_params()

generated_sequence = charnn.generate_from_model(
    model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)
print(len(generated_sequence))
print(generated_sequence)
10000
ACTIINESS and CAIUS

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                

Questions

TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [19]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Why do we split the corpus into sequences instead of training on the whole text?

In [20]:
display_answer(hw3.answers.part1_q1)

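Training on the whole text would mean back-propagating through one extremely long sequence. Back-propagation through time across very long input sequences tends to cause vanishing gradients, which in turn can make the model effectively unlearnable.

In addition, a single very long sequence leads to very long training times and high memory use, since the activations of every time step must be kept for the backward pass, and only one weight update would be made per pass over the text.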
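To make this concrete, here is a minimal sketch of cutting a corpus into fixed-length input/target pairs for next-character prediction. The function name `split_corpus` and the chunking scheme are illustrative assumptions, not the homework's code. With this split, back-propagation only ever runs through `seq_len` time steps per sequence.

```python
def split_corpus(text, seq_len):
    """Split a corpus into (input, target) pairs of length seq_len.

    The target is the input shifted by one character (next-character
    prediction). A trailing remainder that can't fill a pair is dropped.
    """
    pairs = []
    # Each chunk needs seq_len + 1 characters: seq_len inputs plus one extra
    # character so that the last input position also has a target.
    for start in range(0, len(text) - seq_len, seq_len):
        chunk = text[start:start + seq_len + 1]
        pairs.append((chunk[:-1], chunk[1:]))
    return pairs


pairs = split_corpus("to be or not to be", seq_len=5)
```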

Question 2

How is it possible that the generated text clearly shows memory longer than the sequence length?

In [21]:
display_answer(hw3.answers.part1_q2)

Basically, the model's memory is carried by the hidden state, not by the sequence length. Moreover, the hidden state is passed from one batch to the next, so the model can "remember" more than a single sequence length.

In addition, the sequence length used during training doesn't limit the length of the text the model can generate. Because the hidden state is preserved between sequences, the model doesn't only learn the contents of individual sequences; it also learns how consecutive sequences follow each other, which allows it to generate long text.
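A small PyTorch sketch of this idea, with arbitrary sizes chosen for illustration: a GRU processes a long sequence in chunks of `seq_len`, carrying the hidden state from chunk to chunk and detaching it so that gradients are truncated at chunk boundaries. The chunked outputs match processing the whole sequence at once, so outputs in later chunks still depend on inputs from earlier chunks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
long_input = torch.randn(1, 100, 8)  # one long sequence: batch=1, 100 steps
seq_len = 20

# Reference: process the whole sequence in one call.
full_out, _ = gru(long_input)

# Process the same sequence in chunks, carrying the hidden state forward.
h = None  # None means a zero initial hidden state
chunk_outs = []
for start in range(0, long_input.shape[1], seq_len):
    out, h = gru(long_input[:, start:start + seq_len], h)
    h = h.detach()  # keep the state's values, but cut the gradient path
    chunk_outs.append(out)
chunked_out = torch.cat(chunk_outs, dim=1)
```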

Question 3

Why are we not shuffling the order of batches when training?

In [22]:
display_answer(hw3.answers.part1_q3)

Since we want our model to generate full text and not just single sequences, we implement a stateful model rather than a stateless one. Therefore, the order of the sequences matters: the hidden state is propagated between batches within the same epoch, so the memory persists across sequences. If sequence B is fed right after sequence A, we want the network to evaluate B with the memory of what was in A. Shuffling the batches would break this continuity, so keeping the training data in order is what allows the model to generate long sequences even though it trains on short ones.
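One common way to lay out batches for a stateful model (an assumption here, not necessarily what the homework's sampler does) is to give each batch row a contiguous block of sequences, so that row `i` of batch `k+1` continues row `i` of batch `k`. The helper `stateful_batches` is hypothetical:

```python
def stateful_batches(sequences, batch_size):
    """Arrange sequences so row i of batch k+1 continues row i of batch k."""
    num_batches = len(sequences) // batch_size
    sequences = sequences[:num_batches * batch_size]  # drop the incomplete batch
    # Row i owns a contiguous block of num_batches consecutive sequences.
    rows = [sequences[i * num_batches:(i + 1) * num_batches] for i in range(batch_size)]
    # Batch k takes the k-th sequence of every row.
    return [[rows[i][k] for i in range(batch_size)] for k in range(num_batches)]


seqs = [f"s{j}" for j in range(6)]
batches = stateful_batches(seqs, batch_size=2)
```

Here row 0 sees `s0, s1, s2` and row 1 sees `s3, s4, s5` across consecutive batches. Shuffling the batch order would feed a row a sequence that doesn't follow the one its hidden state was computed on.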

Question 4

  1. Why do we lower the temperature for sampling (compared to the default of $1.0$ when training)?
  2. What happens when the temperature is very high and why?
  3. What happens when the temperature is very low and why?
In [23]:
display_answer(hw3.answers.part1_q4)
  1. The lower the temperature, the larger the values we apply the softmax to, since the scores are divided by the temperature. A softmax over larger values gives a sharper distribution, which makes the RNN more confident and more conservative in its samples (small differences between scores are enough to dominate the output, and unlikely candidates are rarely sampled). The network therefore makes "safe" guesses, which is what we want when sampling.

On the other hand, a higher temperature produces a softer probability distribution over the classes, resulting in more diversity, and the network starts making "riskier" guesses. During training we keep the default of 1.0, i.e. the plain softmax, so that the loss is computed on the model's actual predicted distribution.

  2. When the temperature is close to 1, we see many mistakes in words (words without meaning, sampled with strange letters). When the temperature is very high (10, for example), we get gibberish: no words and no sentences, just essentially random letters.

This happens because, as explained in 1, a high temperature makes the network sample "creative" and "riskier" guesses, so it often samples incorrect characters instead of the likely ones that form known words.

  3. When the temperature is very low, the network repeats common words (such as "I", "will", "not", "the"). Every sentence starts with the same words, and no rare or "hard" words are used, only simple, common ones.

This happens because the network samples in a "safe" manner, almost always choosing the most likely next character, so it keeps producing the most common words.
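The effect of temperature can be seen directly by dividing the scores by $T$ before the softmax. A minimal sketch with made-up scores for three characters (`softmax_with_temperature` is a hypothetical helper):

```python
import torch


def softmax_with_temperature(scores, temperature):
    """Softmax of scores / temperature: low T sharpens, high T flattens."""
    return torch.softmax(scores / temperature, dim=-1)


scores = torch.tensor([2.0, 1.0, 0.1])  # made-up scores for 3 characters
probs = {t: softmax_with_temperature(scores, t) for t in (0.1, 1.0, 10.0)}

# Sample the index of the next character from the T=1.0 distribution.
sample = torch.multinomial(probs[1.0], num_samples=1).item()
```

At $T=0.1$ almost all of the probability mass goes to the top-scoring character, while at $T=10$ the three probabilities are close to uniform.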

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 2: Variational Autoencoder

In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset which contains many labeled faces of famous individuals.

We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)

However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/thaer/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/thaer/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/thaer/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

The Variational Autoencoder

An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a neural net with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a neural net with parameters $\bb{\beta}$).

While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.

We define, in Bayesian terminology,

  • The prior distribution $p(\bb{Z})$ on points in the latent space.
  • The likelihood distribution of a sample $\bb{X}$ given a latent-space representation: $p(\bb{X}|\bb{Z})$.
  • The posterior distribution of points in the latent space given a specific instance: $p(\bb{Z}|\bb{X})$.
  • The evidence distribution $p(\bb{X})$ which is the distribution of the instance space due to the generative process.

To create our variational decoder we'll further specify:

  • A parametric likelihood distribution, $p _{\bb{\beta}}(\bb{X} | \bb{z}) = \mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$. The interpretation is that given a latent $\bb{z}$, we map it to a point normally distributed around the point calculated by our decoder neural network. Note that here $\sigma^2$ is a hyperparameter while $\vec{\beta}$ represents the network parameters.
  • A fixed latent-space prior distribution of $p(\bb{Z}) = \mathcal{N}(\bb{0},\bb{I})$.

This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.
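The generative process above can be sketched in a few lines of PyTorch. The decoder here is a hypothetical stand-in for $\Psi_{\bb{\beta}}$, and the sizes and the value of $\sigma^2$ are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, x_dim = 2, 5
sigma2 = 0.1  # likelihood variance sigma^2 (a hyperparameter)

# Hypothetical stand-in for the decoder Psi_beta: any network z -> x-space.
decoder = nn.Sequential(nn.Linear(z_dim, 16), nn.ReLU(), nn.Linear(16, x_dim))

with torch.no_grad():
    z = torch.randn(4, z_dim)  # z ~ p(Z) = N(0, I)
    x_mean = decoder(z)        # instance-space mean Psi_beta(z)
    # x ~ N(Psi_beta(z), sigma^2 I): add isotropic Gaussian noise to the mean.
    x = x_mean + sigma2 ** 0.5 * torch.randn_like(x_mean)
```

In practice the decoder mean $\Psi_{\bb{\beta}}(\bb{z})$ itself is often used as the generated image, skipping the final noise step.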

Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder neural network, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ directly from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma}_{\bb{\alpha}}(\bb{x})$, where $\bb{\sigma}_{\bb{\alpha}}(\bb{x})$ is the element-wise square root of the variance $\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})$.
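A minimal sketch of the reparametrization trick with fixed, made-up values for $\bb{\mu}$ and $\log\bb{\sigma}^2$. The only randomness is $\bb{u}$, so gradients reach both parameters, and the samples still have the intended mean and variance. Note that the scale is the standard deviation $\bb{\sigma} = \exp(\tfrac{1}{2}\log\bb{\sigma}^2)$, not the variance.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_sigma2 = torch.tensor([0.0, -2.0], requires_grad=True)  # log-variances

n = 100_000
u = torch.randn(n, 2)               # u ~ N(0, I): no trainable parameters
sigma = torch.exp(0.5 * log_sigma2)  # standard deviation = sqrt(variance)
z = mu + u * sigma                   # reparametrized samples from q(Z|x)

# The transform from u to z is deterministic, so backprop reaches mu and log_sigma2.
z.sum().backward()
```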

To train a VAE model, we would like to maximize the evidence, $p(\bb{X})$. Since $ p(\bb{X}) = \int p(\bb{X}|{\bb{z}})p(\bb{z})d\bb{z} $, maximizing it maximizes the likelihood of the generated instances over the entire latent space.

The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. As we saw in the lecture, computing $p(\bb{X})$ is intractable, but we can obtain a lower bound for $\log p(\bb{X})$ (the evidence lower bound, "ELBO"):

$$ \log p(\bb{X}) \ge \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }( \log p _{\bb{\beta}}(\bb{X} | \bb{z}) ) - \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{X})\,\left\|\, p(\bb{Z} )\right.\right) $$

where $ \mathcal{D} _{\mathrm{KL}}(q\left\|\right.p) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback-Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z|X})$ instead of the prior distribution $p(\bb{Z})$.

Using the ELBO, the VAE loss becomes,

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ -\log p _{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$

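Writing $\bb{\Sigma}_{\bb{\alpha}}(\bb{x}) = \mathrm{diag}\{\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\}$ for the posterior covariance, recalling that the likelihood is a Gaussian distribution with a diagonal covariance, and applying the reparametrization trick, we can write the above (up to an additive constant that doesn't depend on $\vec{\alpha}$ or $\vec{\beta}$) as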

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} } \left[ \frac{1}{2\sigma^2}\left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$

Model Implementation

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).

First, let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll consider this volume as the features we extract from the input image. Later we'll use these features to create the latent-space representation of the input.

TODO: Implement the EncoderCNN class in the hw3/autoencoder.py module. Implement any CNN architecture you like. If you need "architecture inspiration" you can see e.g. this or this paper.
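As a sanity check on an architecture like the one printed by the cell below, the spatial size after each `nn.Conv2d` (with dilation 1) is $\lfloor (H + 2p - k)/s \rfloor + 1$, per the PyTorch documentation. A small hypothetical helper, `conv2d_out`, traces a $64\times64$ input through the kernel/stride/padding values of that architecture:

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of each conv layer in the printed EncoderCNN
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]
size = 64
sizes = [size]
for k, s, p in layers:
    size = conv2d_out(size, k, s, p)
    sizes.append(size)
```

Each stride-2 layer halves the spatial size, and the final $4\times4$ kernel without padding collapses the remaining $4\times4$ map to $1\times1$.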

In [6]:
import hw3.autoencoder as autoencoder

in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)

h = encoder_cnn(x0)
print(h.shape)

test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
torch.Size([1, 1024, 1, 1])

Now let's implement the CNN part of the Decoder. Again, this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced by your EncoderCNN and output an image with the same dimensions as the Encoder's input. This should be a CNN which is a "mirror image" of the Encoder: for example, replace convolutions with transposed convolutions, downsampling with upsampling, etc. Consult the documentation of ConvTranspose2d to figure out how to reverse your convolutional layers in terms of input and output dimensions.
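For `nn.ConvTranspose2d` with dilation 1 and no `output_padding`, the output size is $(H-1)s - 2p + k$. Tracing the kernel/stride/padding values of a mirrored decoder with a hypothetical helper, `conv_transpose2d_out`, shows how a $1\times1$ volume grows back to $64\times64$:

```python
def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of each transposed conv, mirroring the encoder
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
size = 1
sizes = [size]
for k, s, p in layers:
    size = conv_transpose2d_out(size, k, s, p)
    sizes.append(size)
```

Using the same kernel, stride, and padding as the corresponding encoder layer, in reverse order, exactly inverts the spatial sizes here.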

TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.

In [7]:
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)

test.assertEqual(x0.shape, x0r.shape)

# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
  (cnn): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  )
)
torch.Size([1, 3, 64, 64])
Out[7]:

Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:

  1. Produce a feature vector $\vec{h}$ from the input image $\vec{x}$.
  2. Use two affine transforms to convert the features into the mean and log-variance of the posterior, i.e. $$ \begin{align}
     \bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\
     \log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}}
    \end{align} $$
  3. Use the reparametrization trick to create the latent representation $\vec{z}$.

Note that we model the log of the variance, not the actual variance. The reason is that the log is easier to optimize, since (a) it doesn't have to be positive, and (b) it has a much larger dynamic range. The above formulation is proposed in appendix C of the VAE paper.
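The three encoding steps above might look like the following sketch. `GaussianEncoderHead` is a hypothetical module (the layer names `mu_layer` and `logvar_layer` are borrowed from the model printout below) that assumes the encoder CNN outputs an `(N, h_dim, 1, 1)` volume; it is not the required `VAE.encode()` interface.

```python
import torch
import torch.nn as nn


class GaussianEncoderHead(nn.Module):
    """Hypothetical head mapping CNN features h to (z, mu, log_sigma2)."""

    def __init__(self, h_dim, z_dim):
        super().__init__()
        self.mu_layer = nn.Linear(h_dim, z_dim)      # W_hmu, b_hmu
        self.logvar_layer = nn.Linear(h_dim, z_dim)  # W_hsigma2, b_hsigma2

    def forward(self, features):
        h = features.flatten(start_dim=1)  # (N, h_dim, 1, 1) -> (N, h_dim)
        mu = self.mu_layer(h)
        log_sigma2 = self.logvar_layer(h)  # unconstrained: any real value is valid
        u = torch.randn_like(mu)           # u ~ N(0, I)
        z = mu + u * torch.exp(0.5 * log_sigma2)  # reparametrization trick
        return z, mu, log_sigma2


head = GaussianEncoderHead(h_dim=1024, z_dim=2)
z, mu, log_sigma2 = head(torch.randn(3, 1024, 1, 1))
```

Because the layer predicts $\log\bb{\sigma}^2$, the variance `torch.exp(log_sigma2)` is positive no matter what the linear layer outputs.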

TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__().

In [8]:
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)

z, mu, log_sigma2 = vae.encode(x0)

test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)

print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')

# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
    for i in range(N):
        Z[i], _, _ = vae.encode(x0)
        ax.scatter(*Z[i].cpu().numpy())

# Should be close to the above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace)
      (11): Conv2d(512, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace)
      (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace)
      (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace)
      (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
  )
  (mu_layer): Linear(in_features=1024, out_features=2, bias=True)
  (logvar_layer): Linear(in_features=1024, out_features=2, bias=True)
  (decoder_reconstruct_features): Linear(in_features=2, out_features=1024, bias=True)
)
mu(x0)=[0.11242073, -0.27497584], sigma2(x0)=[1.011631, 0.7942234]
sampled mu tensor([ 0.1128, -0.2261])
sampled sigma2 tensor([0.9796, 0.7056])

Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows:

  1. Produce a feature vector $\tilde{\vec{h}}$ from the latent vector $\vec{z}$ using an affine transform.
  2. Reconstruct an image $\tilde{\vec{x}}$ from $\tilde{\vec{h}}$.

TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.
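A corresponding hypothetical sketch of the two decoding steps, assuming the decoder CNN takes an `(N, h_dim, 1, 1)` volume. `DecoderHead`, the tiny `ConvTranspose2d` stand-in, and the final `tanh` are illustrative choices, not requirements:

```python
import torch
import torch.nn as nn


class DecoderHead(nn.Module):
    """Hypothetical sketch of decode(): affine map, reshape, then a decoder CNN."""

    def __init__(self, z_dim, h_dim, features_decoder):
        super().__init__()
        self.h_dim = h_dim
        self.fc = nn.Linear(z_dim, h_dim)         # z -> h_tilde (affine transform)
        self.features_decoder = features_decoder  # h_tilde volume -> image

    def forward(self, z):
        # Reshape the flat features into an (N, h_dim, 1, 1) activation volume.
        h = self.fc(z).view(-1, self.h_dim, 1, 1)
        x = self.features_decoder(h)
        # Optional design choice: squash into [-1, 1], the range of the
        # normalized images.
        return torch.tanh(x)


# Stand-in decoder CNN: a single transposed conv turning 1x1 into 4x4.
head = DecoderHead(z_dim=2, h_dim=16,
                   features_decoder=nn.ConvTranspose2d(16, 3, kernel_size=4))
x_rec = head(torch.randn(5, 2))
```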

In [9]:
x0r = vae.decode(z)

test.assertSequenceEqual(x0r.shape, x0.shape)

Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.

In [10]:
x0r, mu, log_sigma2 = vae(x0)

test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
Out[10]:

Loss Implementation

In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:

$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}) $$

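where $d_z$ is the dimension of the latent space. This is twice the per-sample form of the loss above (the closed-form KL between $\mathcal{N}(\bb{\mu},\bb{\Sigma})$ and $\mathcal{N}(\bb{0},\bb{I})$ is $\frac{1}{2}\left(\mathrm{tr}\,\bb{\Sigma} + \|\bb{\mu}\|_2^2 - d_z - \log\det\bb{\Sigma}\right)$), and scaling by 2 doesn't change the minimizer. This pointwise loss is the quantity that we'll compute and minimize with gradient descent.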
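A sketch of this point-wise loss in PyTorch, averaged over the batch. For a diagonal $\bb{\Sigma}$, $\mathrm{tr}\,\bb{\Sigma}$ is the sum of the variances and $\log\det\bb{\Sigma}$ is the sum of the log-variances. How the reconstruction term is normalized (summed or averaged over pixels) is an assumption here, so this hypothetical `vae_loss_sketch` is not guaranteed to reproduce the reference value checked in `test_vae_loss()`.

```python
import torch


def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    """Point-wise VAE loss averaged over the batch.

    Returns (total_loss, data_loss, kldiv_loss). Averaging the reconstruction
    error over pixels is an assumed normalization choice.
    """
    n = x.shape[0]
    # Reconstruction term ||x - xr||^2 / sigma^2, here averaged over pixels.
    data_loss = ((x - xr) ** 2).reshape(n, -1).mean(dim=1) / x_sigma2
    # Closed-form term: tr(Sigma) + ||mu||^2 - d_z - log det(Sigma).
    d_z = z_mu.shape[1]
    kldiv_loss = (torch.exp(z_log_sigma2).sum(dim=1)
                  + (z_mu ** 2).sum(dim=1)
                  - d_z
                  - z_log_sigma2.sum(dim=1))
    data_loss, kldiv_loss = data_loss.mean(), kldiv_loss.mean()
    return data_loss + kldiv_loss, data_loss, kldiv_loss
```

When the reconstruction is perfect and the posterior equals the prior ($\bb{\mu}=\bb{0}$, $\log\bb{\sigma}^2=\bb{0}$), every term cancels and the loss is zero.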

TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.

In [11]:
from hw3.autoencoder import vae_loss
torch.manual_seed(42)

def test_vae_loss():
    # Test data
    N, C, H, W = 10, 3, 64, 64 
    z_dim = 32
    x  = torch.randn(N, C, H, W)*2 - 1
    xr = torch.randn(N, C, H, W)*2 - 1
    z_mu = torch.randn(N, z_dim)
    z_log_sigma2 = torch.randn(N, z_dim)
    x_sigma2 = 0.9
    
    loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
    
    test.assertAlmostEqual(loss.item(), 58.3234367, delta=1e-3)
    return loss

test_vae_loss()
Out[11]:
tensor(58.3234)

Sampling

The main advantage of a VAE is that it can be used as a generative model by sampling the latent space, since we optimize for a Normal prior $p(\bb{Z})$ in the loss function. Let's now implement this so that we can visualize how our model is doing when we train.

TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.

In [12]:
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)

Training

Time to train!

TODO:

  1. Implement the VAETrainer class in the hw3/training.py module. Make sure to implement the checkpoints feature of the Trainer class if you haven't done so already in Part 1.
  2. Tweak the hyperparameters in the part2_vae_hyperparams() function within the hw3/answers.py module.
In [13]:
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']

# Data
n_train = int(len(ds_gwb)*0.9)
# random_split requires the lengths to sum to len(ds_gwb)
split_lengths = [n_train, len(ds_gwb) - n_train]
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test  = DataLoader(ds_test,  batch_size, shuffle=True)
im_size = ds_train[0][0].shape

# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)

# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)

# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
    return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)

# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show model and hypers
print(vae)
print(hp)
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace)
      (11): Conv2d(512, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace)
      (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace)
      (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU(inplace)
      (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    )
  )
  (mu_layer): Linear(in_features=1024, out_features=128, bias=True)
  (logvar_layer): Linear(in_features=1024, out_features=128, bias=True)
  (decoder_reconstruct_features): Linear(in_features=128, out_features=1024, bias=True)
)
{'batch_size': 32, 'h_dim': 1024, 'z_dim': 128, 'x_sigma2': 0.001, 'learn_rate': 0.001, 'betas': (0.5, 0.999)}

TODO:

  1. Run the following block to train. It will sample some images from your model every few epochs so you can see the progress.
  2. When you're satisfied with your results, rename the checkpoints file by adding _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training.
In [14]:
import IPython.display

def post_epoch_fn(epoch, train_result, test_result, verbose):
    # Plot some samples if this is a verbose epoch
    if verbose:
        samples = vae.sample(n=5)
        fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final
else:
    res = trainer.fit(dl_train, dl_test,
                      num_epochs=200, early_stopping=20, print_every=10,
                      checkpoints=checkpoint_file,
                      post_epoch_fn=post_epoch_fn)
    
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
*** Loading final checkpoint file checkpoints/vae_final instead of training
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [15]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.

In [16]:
display_answer(hw3.answers.part2_q1)

The hyperparameter $\sigma^2$ sets the relative weight of the data (reconstruction) loss and the KL-divergence loss in the total loss, because the data term is divided by $\sigma^2$. This in turn controls how much variance the model's generated images have. With a high $\sigma^2$, the data term becomes small and the KL term dominates. The model then focuses on matching the latent-space distribution to the prior and tends to generate similar images (close to an "average" image). With a low $\sigma^2$, the data term dominates. The model then focuses on reconstructing the original images that were encoded into each latent vector $\bb{z}$. Since latent vectors are sampled randomly, different samples then decode to noticeably different images.
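
To make the role of $\sigma^2$ concrete, here is a minimal sketch of a VAE loss in which the reconstruction term is divided by `x_sigma2`. The function name `vae_loss_sketch` is hypothetical. It uses the standard closed-form KL divergence between a diagonal Gaussian and $\mathcal{N}(\bb{0},\bb{I})$. The exact normalization used by `autoencoder.vae_loss` in the assignment may differ.

```python
import torch


def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    """Hypothetical VAE loss: data term scaled by 1/x_sigma2, plus KL to N(0, I).

    Returns (total, data_loss, kldiv_loss), each averaged over the batch.
    """
    n = x.shape[0]
    d_x = x[0].numel()
    # Squared reconstruction error per sample, divided by x_sigma2 and the input dimension
    data_loss = ((x - xr) ** 2).reshape(n, -1).sum(dim=1) / (x_sigma2 * d_x)
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ) per sample
    kldiv_loss = 0.5 * (z_log_sigma2.exp() + z_mu ** 2 - 1 - z_log_sigma2).sum(dim=1)
    data_loss, kldiv_loss = data_loss.mean(), kldiv_loss.mean()
    return data_loss + kldiv_loss, data_loss, kldiv_loss


# Same inputs, two values of x_sigma2
torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
xr = x + 0.1 * torch.randn_like(x)
z_mu, z_log_sigma2 = torch.zeros(4, 16), torch.zeros(4, 16)
for s2 in (0.001, 1.0):
    _, data, kl = vae_loss_sketch(x, xr, z_mu, z_log_sigma2, s2)
    print(f"x_sigma2={s2}: data={data.item():.4f}, kl={kl.item():.4f}")
```

Lowering `x_sigma2` from 1.0 to 0.001 multiplies the data term by 1000 while leaving the KL term unchanged. This is the rebalancing between the two loss terms described above.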

Question 2

  1. Explain the purpose of both parts of the VAE loss term - reconstruction loss and KL divergence loss.
  2. How is the latent-space distribution affected by the KL loss term?
  3. What's the benefit of this?
In [17]:
display_answer(hw3.answers.part2_q2)

In the VAE loss term, we sum up two separate losses:

  1. The generative loss, which is a mean squared error that measures how accurately the network reconstructed the images. This term encourages the decoder to learn to reconstruct the data. If the decoder’s output does not reconstruct the data well, statistically we say that the decoder parameterizes a likelihood distribution that does not place much probability mass on the true data.

  2. A latent loss, which is the KL divergence that measures how closely the latent variables match a unit Gaussian. We can think of it as a regularization term. It is the Kullback-Leibler divergence between the encoder's distribution $q(\bb{Z}|\bb{X})$ and the prior distribution $p(\bb{Z})$. The encoder's distribution approximates the posterior distribution $p(\bb{Z}|\bb{X})$ of points in the latent space given a specific instance. This divergence measures how much information is lost when q is used to represent p, i.e. how close q is to p.

We sample the latent space using a standard normal distribution $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ and then apply the reparametrization trick. The KL divergence term therefore shapes the latent-space distribution. It pushes the encoder's distribution of latent codes to be tightly packed around the origin, i.e. to resemble $\mathcal{N}(\bb{0},\bb{I})$. The benefit is that latent vectors drawn from the standard normal distribution at sampling time are likely to land in regions the decoder saw during training. The decoder is therefore likely to map them to images close to the ones we trained on.
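
The reparametrization trick mentioned above can be sketched as follows. `reparametrize` is a hypothetical helper name. The sketch assumes the encoder outputs a mean and a log-variance per latent dimension, as the `mu_layer` and `logvar_layer` in the printed model suggest. Writing $\bb{z} = \bb{\mu} + \bb{\sigma} \odot \bb{u}$ with $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ moves all the randomness into $\bb{u}$. Gradients can therefore flow back into $\bb{\mu}$ and $\log\bb{\sigma}^2$.

```python
import torch


def reparametrize(z_mu, z_log_sigma2):
    # sigma = exp(0.5 * log sigma^2); the noise u carries all the randomness
    u = torch.randn_like(z_mu)
    return z_mu + torch.exp(0.5 * z_log_sigma2) * u


torch.manual_seed(0)
z_mu = torch.zeros(3, 4, requires_grad=True)
z_log_sigma2 = torch.zeros(3, 4, requires_grad=True)
z = reparametrize(z_mu, z_log_sigma2)
z.sum().backward()
# dz/dmu is 1 for every element, so the gradient reaches the encoder outputs
print(z_mu.grad)
```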

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 3: Generative Adversarial Networks

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset

We'll use the same data as in Part 2.

But again, to use a custom dataset, edit the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236605.plot as plot
import cs236605.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236605.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/thaer/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/thaer/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/thaer/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extracted images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)
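
`T.Normalize` computes `(x - mean) / std` per channel. With mean and std of 0.5, it maps pixel values from [0, 1] to [-1, 1]. A minimal sketch of that mapping and its inverse follows. The helper names `normalize` and `denormalize` are hypothetical. The inverse is useful when turning generated images back into a displayable range.

```python
import torch

# Per-channel statistics, shaped to broadcast over (C, H, W) tensors
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)


def normalize(x):
    # Same per-channel formula T.Normalize applies: (x - mean) / std
    return (x - mean) / std


def denormalize(x):
    # Inverse mapping back to [0, 1], e.g. before displaying generator output
    return (x * std + mean).clamp(0.0, 1.0)


print(normalize(torch.zeros(3, 1, 1)).flatten())  # tensor([-1., -1., -1.])
```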

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(10,5), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

Generative Adversarial Nets (GANs)

GANs, first proposed in a paper by Ian Goodfellow in 2014, are today arguably the most popular type of generative model. GANs currently produce state-of-the-art results in generative tasks across many different domains.

In a GAN model, two different neural networks compete against each other: A generator and a discriminator.

  • The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.

  • The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

Training GANs

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen so as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen so as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:

$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as a data loss term.
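
As a small numerical illustration of the min-max objective above, the sketch below estimates its value from a batch of discriminator probabilities. The name `gan_value` is hypothetical. In the original GAN analysis, when $p_{\bb{\gamma}}(\bb{X}) = p(\bb{X})$ the optimal discriminator outputs $1/2$ everywhere, and the objective equals $-\log 4$. A discriminator that confidently separates real from fake instead pushes the value toward its upper bound of 0.

```python
import math

import torch


def gan_value(d_real, d_fake):
    # Monte-Carlo estimate of E_x[log D(x)] + E_z[log(1 - D(G(z)))]
    # d_real, d_fake: discriminator probabilities in (0, 1)
    return torch.log(d_real).mean() + torch.log1p(-d_fake).mean()


confused = gan_value(torch.full((8,), 0.5), torch.full((8,), 0.5))
confident = gan_value(torch.full((8,), 0.99), torch.full((8,), 0.01))
print(confused.item(), -math.log(4))  # both about -1.386
print(confident.item())               # close to 0
```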

Model Implementation

We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.

TODO: Implement the Discriminator class in the hw3/gan.py module. If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.

In [6]:
import hw3.gan as gan

dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)

d0 = dsc(x0)
print(d0.shape)

test.assertSequenceEqual(d0.shape, (1,1))
torch.Size([3, 64, 64])
Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)
torch.Size([1, 1])
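
The printed architecture above follows the DCGAN pattern. The following is a self-contained sketch that reproduces the same layer stack. The class name `DiscriminatorSketch` is hypothetical, and the channel widths are taken from the printout rather than from the assignment's code. It returns raw scores rather than probabilities. This assumes the loss functions apply the sigmoid internally, e.g. through a BCE-with-logits loss.

```python
import torch
import torch.nn as nn


class DiscriminatorSketch(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        channels = [in_channels, 64, 128, 256, 512]
        layers = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # Each stride-2 conv halves the spatial size: 64 -> 32 -> 16 -> 8 -> 4
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1, bias=False))
            if i > 0:
                # As in DCGAN, no BatchNorm right after the input layer
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # A 4x4 kernel without padding collapses the 4x4 map to a single score
        layers.append(nn.Conv2d(channels[-1], 1, 4, stride=1, bias=False))
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        # Raw scores (logits) of shape (N, 1); no sigmoid is applied here
        return self.cnn(x).reshape(x.shape[0], 1)


print(DiscriminatorSketch()(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 1])
```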

TODO: Implement the Generator class in the hw3/gan.py module. If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.

In [7]:
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)

z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)

test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
  (generator): Sequential(
    (0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  )
)
torch.Size([1, 3, 64, 64])
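
Similarly, the generator printout can be reproduced with the sketch below. The class name `GeneratorSketch` is hypothetical. Each latent vector is reshaped into a `z_dim x 1 x 1` map, which the first transposed convolution expands to the 4x4 `featuremap_size`. The final `tanh` is an assumption: it does not appear in the printout, but DCGAN uses it, and it matches the [-1, 1] normalization of the dataset.

```python
import torch
import torch.nn as nn


class GeneratorSketch(nn.Module):
    def __init__(self, z_dim=128, out_channels=3):
        super().__init__()
        self.z_dim = z_dim
        channels = [z_dim, 512, 256, 128, 64]
        layers = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            if i == 0:
                # 1x1 -> 4x4: kernel 4, stride 1, no padding
                layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=1, bias=False))
            else:
                # Each stride-2 layer doubles the spatial size: 4 -> 8 -> 16 -> 32
                layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1, bias=False))
            layers += [nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]
        # 32 -> 64 and down to the image channels
        layers.append(nn.ConvTranspose2d(channels[-1], out_channels, 4, stride=2, padding=1, bias=False))
        self.cnn = nn.Sequential(*layers)

    def forward(self, z):
        # Treat each latent vector as a z_dim x 1 x 1 feature map
        x = self.cnn(z.reshape(z.shape[0], self.z_dim, 1, 1))
        # Assumption: squash to [-1, 1] to match the dataset normalization
        return torch.tanh(x)


print(GeneratorSketch()(torch.randn(2, 128)).shape)  # torch.Size([2, 3, 64, 64])
```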

Loss Implementation

Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.

GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization that helps prevent the discriminator from overfitting.

We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.

TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.

In [8]:
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)

y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10

loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)

test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)
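
The following is a minimal sketch of a noisy-label discriminator loss. It assumes the discriminator outputs raw logits (the test above feeds values in [0, 10]). It also treats `label_noise` as the total width of the uniform interval around each label. Both the function name and that noise convention are assumptions, so this sketch is not expected to reproduce the reference value 6.4808731 exactly. Noisy targets can fall slightly outside [0, 1], in which case the cross-entropy with logits can dip just below zero. This is consistent with the occasional small negative discriminator losses in the training log below.

```python
import torch
import torch.nn.functional as F


def discriminator_loss_sketch(y_data, y_generated, data_label=1, label_noise=0.0):
    # Label assigned to generated samples is the opposite of data_label (0 or 1)
    generated_label = 1 - data_label
    # Uniform noise in [-label_noise/2, label_noise/2) around each target (assumed convention)
    noise_data = (torch.rand_like(y_data) - 0.5) * label_noise
    noise_generated = (torch.rand_like(y_generated) - 0.5) * label_noise
    targets_data = data_label + noise_data
    targets_generated = generated_label + noise_generated
    # Inputs are logits, so use the numerically stable BCE-with-logits form
    loss_data = F.binary_cross_entropy_with_logits(y_data, targets_data)
    loss_generated = F.binary_cross_entropy_with_logits(y_generated, targets_generated)
    return loss_data + loss_generated


torch.manual_seed(42)
y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10
print(discriminator_loss_sketch(y_data, y_generated, data_label=1, label_noise=0.3))
```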

Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$

which can also be seen as a cross-entropy term.

TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.

In [9]:
from hw3.gan import generator_loss_fn
torch.manual_seed(42)

y_generated = torch.rand(20) * 10

loss = generator_loss_fn(y_generated, data_label=1)
print(loss)

test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-3)
tensor(0.0223)
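
The generator loss can use the same cross-entropy with every target set to `data_label`, so the generator is rewarded when the discriminator scores its samples as real. The following is a sketch under the same logits assumption, with a hypothetical function name. With `data_label=1` it equals the batch mean of $-\log \sigma(y)$, the empirical form of the expression above.

```python
import torch
import torch.nn.functional as F


def generator_loss_sketch(y_generated, data_label=1):
    # Target "real" for every generated sample; no label noise for the generator
    targets = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, targets)


torch.manual_seed(42)
y_generated = torch.rand(20) * 10
print(generator_loss_sketch(y_generated, data_label=1))
```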

Sampling

Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.

There is an important nuance, however. Sampling is required during the process of training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to have gradients (i.e., to be part of the Generator's computation graph).

TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.

In [10]:
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())

samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)
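
One way to support the `with_grad` flag is `torch.set_grad_enabled`, which turns autograd recording on or off inside a `with` block. The sketch below is a free function rather than the `Generator.sample` method. Its `z_dim` argument and the toy `nn.Linear` generator are assumptions for illustration.

```python
import torch
import torch.nn as nn


def sample(gen, n, z_dim, with_grad=False):
    # Put z on the same device as the generator's parameters
    device = next(gen.parameters()).device
    # Record the computation graph only when with_grad is True
    with torch.set_grad_enabled(with_grad):
        z = torch.randn(n, z_dim, device=device)
        return gen(z)


toy_gen = nn.Linear(8, 12)  # stand-in for a real generator
print(sample(toy_gen, 5, 8, with_grad=False).grad_fn)             # None
print(sample(toy_gen, 5, 8, with_grad=True).grad_fn is not None)  # True
```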

Training

Training GANs is a bit different since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.

As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)

TODO:

  1. Implement the train_batch function in the hw3/gan.py module.
  2. Tweak the hyperparameters in the part3_gan_hyperparams() function within the hw3/answers.py module.
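
The training step described above alternates two updates on each batch. A minimal sketch follows. The function name `train_batch_sketch` and its extra `z_dim` argument are assumptions; the assignment's `train_batch` would presumably draw samples with `gen.sample(...)` instead. The toy `nn.Linear` models only demonstrate that it runs. The generated batch is detached for the discriminator step, so that step does not compute gradients for the generator. The same batch is then reused with gradients for the generator step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_batch_sketch(dsc, gen, dsc_loss_fn, gen_loss_fn,
                       dsc_optimizer, gen_optimizer, x_data, z_dim):
    n = x_data.shape[0]
    z = torch.randn(n, z_dim, device=x_data.device)
    x_fake = gen(z)  # keeps the generator's computation graph

    # 1) Discriminator step: real vs. detached fake samples
    #    (zero_grad also clears grads left on dsc by the previous generator step)
    dsc_optimizer.zero_grad()
    dsc_loss = dsc_loss_fn(dsc(x_data), dsc(x_fake.detach()))
    dsc_loss.backward()
    dsc_optimizer.step()

    # 2) Generator step: score the same fakes with the updated discriminator
    gen_optimizer.zero_grad()
    gen_loss = gen_loss_fn(dsc(x_fake))
    gen_loss.backward()
    gen_optimizer.step()

    return dsc_loss.item(), gen_loss.item()


def toy_dsc_loss(y_data, y_generated):
    # Plain BCE with hard labels: 1 for real, 0 for generated
    return (F.binary_cross_entropy_with_logits(y_data, torch.ones_like(y_data))
            + F.binary_cross_entropy_with_logits(y_generated, torch.zeros_like(y_generated)))


def toy_gen_loss(y_generated):
    return F.binary_cross_entropy_with_logits(y_generated, torch.ones_like(y_generated))


torch.manual_seed(0)
dsc, gen = nn.Linear(6, 1), nn.Linear(4, 6)
dsc_opt = torch.optim.Adam(dsc.parameters(), lr=1e-2)
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-2)
print(train_batch_sketch(dsc, gen, toy_dsc_loss, toy_gen_loss, dsc_opt, gen_opt,
                         torch.randn(8, 6), z_dim=4))
```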
In [11]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape

# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)

# Optimizer
def create_optimizer(model_params, opt_params):
    # Copy so that popping 'type' does not mutate the hyperparameter dict
    opt_params = opt_params.copy()
    optimizer_type = opt_params.pop('type')
    return getattr(optim, optimizer_type)(model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])

# Loss
def dsc_loss_fn(y_data, y_generated):
    return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])

def gen_loss_fn(y_generated):
    return gan.generator_loss_fn(y_generated, hp['data_label'])

# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show hypers
print(hp)
torch.Size([3, 64, 64])
{'batch_size': 32, 'z_dim': 128, 'data_label': 1, 'label_noise': 0.1, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0002}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0002}}

TODO:

  1. Run the following block to train. It will sample some images from your model every few epochs so you can see the progress.
  2. When you're satisfied with your results, rename the checkpoints file by adding _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training.
In [12]:
import IPython.display
import tqdm
from hw3.gan import train_batch

num_epochs = 100

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device)
    checkpoint_file = checkpoint_file_final

for epoch_idx in range(num_epochs):
    # We'll accumulate batch losses and show an average once per epoch.
    dsc_losses = []
    gen_losses = []
    print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
    
    with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
        for batch_idx, (x_data, _) in enumerate(dl_train):
            x_data = x_data.to(device)
            dsc_loss, gen_loss = train_batch(
                dsc, gen,
                dsc_loss_fn, gen_loss_fn,
                dsc_optimizer, gen_optimizer,
                x_data)
            dsc_losses.append(dsc_loss)
            gen_losses.append(gen_loss)
            pbar.update()

    dsc_avg_loss, gen_avg_loss = np.mean(dsc_losses), np.mean(gen_losses)
    print(f'Discriminator loss: {dsc_avg_loss}')
    print(f'Generator loss:     {gen_avg_loss}')
        
    samples = gen.sample(5, with_grad=False)
    fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
    IPython.display.display(fig)
    plt.close(fig)
--- EPOCH 1/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.26it/s]
Discriminator loss: 0.4022838762577842
Generator loss:     3.6342799803789925
--- EPOCH 2/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.24it/s]
Discriminator loss: 0.060317304560585934
Generator loss:     6.479255900663488
--- EPOCH 3/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.26it/s]
Discriminator loss: 0.03782639088218703
Generator loss:     7.151635422426112
--- EPOCH 4/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.15it/s]
Discriminator loss: 0.022160114939598477
Generator loss:     7.7161625132841225
--- EPOCH 5/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.27it/s]
Discriminator loss: 0.027402905089890257
Generator loss:     7.482312258552103
--- EPOCH 6/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.06it/s]
Discriminator loss: 0.043317326399333334
Generator loss:     9.89507178699269
--- EPOCH 7/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.31it/s]
Discriminator loss: 0.0221371830386274
Generator loss:     7.749110249912038
--- EPOCH 8/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.26it/s]
Discriminator loss: 0.0246672196971143
Generator loss:     8.034614478840547
--- EPOCH 9/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.45it/s]
Discriminator loss: 0.023247391862027785
Generator loss:     9.69900504280539
--- EPOCH 10/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.07it/s]
Discriminator loss: 0.0029995923533159144
Generator loss:     11.594942738028134
--- EPOCH 11/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.02it/s]
Discriminator loss: 0.028653493151068687
Generator loss:     7.9471767088946175
--- EPOCH 12/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.14it/s]
Discriminator loss: 0.0272491575930925
Generator loss:     8.102576536290785
--- EPOCH 13/100 ---
100%|██████████| 17/17 [00:02<00:00,  5.71it/s]
Discriminator loss: -0.014067287228124984
Generator loss:     8.62743439393885
--- EPOCH 14/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.34it/s]
Discriminator loss: -0.010119837270501782
Generator loss:     9.86256114174338
--- EPOCH 15/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.35it/s]
Discriminator loss: 0.02586004406432895
Generator loss:     8.529943746678969
--- EPOCH 16/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.13it/s]
Discriminator loss: 0.034082081497592086
Generator loss:     7.517268180847168
--- EPOCH 17/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.24it/s]
Discriminator loss: 0.01741193958065089
Generator loss:     8.549176188076244
--- EPOCH 18/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.45it/s]
Discriminator loss: 0.01591136438005111
Generator loss:     12.289215985466452
--- EPOCH 19/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.11it/s]
Discriminator loss: 0.021831633413539213
Generator loss:     12.017369522767908
--- EPOCH 20/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.51it/s]
Discriminator loss: 0.044148648595985246
Generator loss:     11.444228452794691
--- EPOCH 21/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.06it/s]
Discriminator loss: 0.04460623887751032
Generator loss:     9.658961800967946
--- EPOCH 22/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.11it/s]
Discriminator loss: 0.03852704238584813
Generator loss:     8.985782483044792
--- EPOCH 23/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.84it/s]
Discriminator loss: 0.02146548087544301
Generator loss:     9.652920218075023
--- EPOCH 24/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.46it/s]
Discriminator loss: 0.07703010562588186
Generator loss:     12.977922944461598
--- EPOCH 25/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.38it/s]
Discriminator loss: 0.05186936664668953
Generator loss:     9.10474510753856
--- EPOCH 26/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.07it/s]
Discriminator loss: 0.04818831110263572
Generator loss:     9.511353156145882
--- EPOCH 27/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.07it/s]
Discriminator loss: 0.052439905483933055
Generator loss:     7.823462850907269
--- EPOCH 28/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.12it/s]
Discriminator loss: 0.01880163756911369
Generator loss:     9.437497110927806
--- EPOCH 29/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.19it/s]
Discriminator loss: 0.08108953079756569
Generator loss:     12.201773110558005
--- EPOCH 30/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.97it/s]
Discriminator loss: 0.07465962216477184
Generator loss:     8.944892518660602
--- EPOCH 31/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.09it/s]
Discriminator loss: 0.10846161968348657
Generator loss:     10.04673220129574
--- EPOCH 32/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.13it/s]
Discriminator loss: 0.046988626996822214
Generator loss:     8.177918349995332
--- EPOCH 33/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.18it/s]
Discriminator loss: 0.08862899304093684
Generator loss:     10.522600903230554
--- EPOCH 34/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.19it/s]
Discriminator loss: 0.03626178237883484
Generator loss:     9.881071960224824
--- EPOCH 35/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.44it/s]
Discriminator loss: 0.055667064864845836
Generator loss:     7.367160741020651
--- EPOCH 36/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.46it/s]
Discriminator loss: 0.06483911109321258
Generator loss:     10.38596804001752
--- EPOCH 37/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.06it/s]
Discriminator loss: 0.058702757143798995
Generator loss:     11.708405719083899
--- EPOCH 38/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.17it/s]
Discriminator loss: 0.02440410994869821
Generator loss:     8.158684618332806
--- EPOCH 39/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.22it/s]
Discriminator loss: 0.046541558797745144
Generator loss:     7.462915841270895
--- EPOCH 40/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.12it/s]
Discriminator loss: 0.03760329162811532
Generator loss:     8.998805803411146
--- EPOCH 41/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.07it/s]
Discriminator loss: 0.031060095629928744
Generator loss:     10.768140316009521
--- EPOCH 42/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.46it/s]
Discriminator loss: 0.03701003536801128
Generator loss:     9.96671858955832
--- EPOCH 43/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.21it/s]
Discriminator loss: 0.07419460507876732
Generator loss:     11.037597824545468
--- EPOCH 44/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.27it/s]
Discriminator loss: 0.05102672143017545
Generator loss:     7.939053535461426
--- EPOCH 45/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.06it/s]
Discriminator loss: 0.02913084932986428
Generator loss:     8.136023661669563
--- EPOCH 46/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.30it/s]
Discriminator loss: 0.05055817490553155
Generator loss:     9.42395440269919
--- EPOCH 47/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.39it/s]
Discriminator loss: 0.023596799494150805
Generator loss:     9.703780202304616
--- EPOCH 48/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.33it/s]
Discriminator loss: 0.030125460203956154
Generator loss:     10.451448833241182
--- EPOCH 49/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.33it/s]
Discriminator loss: 0.02489055846543873
Generator loss:     9.535749435424805
--- EPOCH 50/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.30it/s]
Discriminator loss: 0.03878423582543345
Generator loss:     10.403770643122057
--- EPOCH 51/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.15it/s]
Discriminator loss: 0.05763880782486761
Generator loss:     11.011498226838953
--- EPOCH 52/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.41it/s]
Discriminator loss: 0.047397276496185976
Generator loss:     11.10790460249957
--- EPOCH 53/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.24it/s]
Discriminator loss: 0.04060671998954871
Generator loss:     11.847977918737074
--- EPOCH 54/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.25it/s]
Discriminator loss: 0.013352618224042304
Generator loss:     8.680926827823415
--- EPOCH 55/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.24it/s]
Discriminator loss: 0.04174211533630595
Generator loss:     9.133586546953987
--- EPOCH 56/100 ---
100%|██████████| 17/17 [00:02<00:00,  5.70it/s]
Discriminator loss: -0.013932405697072254
Generator loss:     8.309969986186308
--- EPOCH 57/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.38it/s]
Discriminator loss: 0.04455560506047571
Generator loss:     11.780052858240465
--- EPOCH 58/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.35it/s]
Discriminator loss: 0.042005222016835916
Generator loss:     11.220606944140266
--- EPOCH 59/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.36it/s]
Discriminator loss: 0.044789898921461666
Generator loss:     9.957868211409625
--- EPOCH 60/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.22it/s]
Discriminator loss: 0.026172436773777008
Generator loss:     9.146281046025893
--- EPOCH 61/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.14it/s]
Discriminator loss: 0.019368157776839593
Generator loss:     7.496942379895379
--- EPOCH 62/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.17it/s]
Discriminator loss: 0.0009296858573661131
Generator loss:     7.7319605771233055
--- EPOCH 63/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.20it/s]
Discriminator loss: 0.005501605049871346
Generator loss:     7.319698866675882
--- EPOCH 64/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.21it/s]
Discriminator loss: 0.03231088104931747
Generator loss:     8.491457995246439
--- EPOCH 65/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.05it/s]
Discriminator loss: 0.02547282583135016
Generator loss:     9.01593261606553
--- EPOCH 66/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.00it/s]
Discriminator loss: 0.03672037846134866
Generator loss:     8.19250634137322
--- EPOCH 67/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.04it/s]
Discriminator loss: -0.002439247608623084
Generator loss:     7.089772673214183
--- EPOCH 68/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.07it/s]
Discriminator loss: 0.041782958871301484
Generator loss:     7.825737644644344
--- EPOCH 69/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.07it/s]
Discriminator loss: 0.03502945229411125
Generator loss:     11.003360159256879
--- EPOCH 70/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.96it/s]
Discriminator loss: 0.03814027231077061
Generator loss:     9.721833565655876
--- EPOCH 71/100 ---
100%|██████████| 17/17 [00:02<00:00,  7.02it/s]
Discriminator loss: 0.022525475908289936
Generator loss:     8.795534947339226
--- EPOCH 72/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.88it/s]
Discriminator loss: 0.05614198831950917
Generator loss:     12.800697354709401
--- EPOCH 73/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.84it/s]
Discriminator loss: 0.02559424312237431
Generator loss:     14.83996228610768
--- EPOCH 74/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.86it/s]
Discriminator loss: 0.049784058047568094
Generator loss:     10.58527811835794
--- EPOCH 75/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.90it/s]
Discriminator loss: 0.04463378361919347
Generator loss:     8.552692413330078
--- EPOCH 76/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.94it/s]
Discriminator loss: 0.013734047739383052
Generator loss:     11.541373000425452
--- EPOCH 77/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.94it/s]
Discriminator loss: 0.05255987528054153
Generator loss:     9.21201332877664
--- EPOCH 78/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.92it/s]
Discriminator loss: -0.007517707369783346
Generator loss:     10.403826489168054
--- EPOCH 79/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.91it/s]
Discriminator loss: 0.023977108683217976
Generator loss:     10.500497144811293
--- EPOCH 80/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.94it/s]
Discriminator loss: 0.06533729909535717
Generator loss:     10.362654573777142
--- EPOCH 81/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.98it/s]
Discriminator loss: -0.03014084159889642
Generator loss:     16.25366485820097
--- EPOCH 82/100 ---
100%|██████████| 17/17 [00:02<00:00,  6.14it/s]
Discriminator loss: 0.08528617495561347
Generator loss:     11.503966443678912
--- EPOCH 83/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.20it/s]
Discriminator loss: 0.0485588846837773
Generator loss:     10.73743449940401
--- EPOCH 84/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.74it/s]
Discriminator loss: 0.02418605731252362
Generator loss:     10.099281254936667
--- EPOCH 85/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.20it/s]
Discriminator loss: 0.031646001755314714
Generator loss:     9.45245922313017
--- EPOCH 86/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.10it/s]
Discriminator loss: 0.02652003221652087
Generator loss:     7.922471411087933
--- EPOCH 87/100 ---
100%|██████████| 17/17 [00:03<00:00,  5.53it/s]
Discriminator loss: 0.08475906149867703
Generator loss:     10.272227764129639
--- EPOCH 88/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.12it/s]
Discriminator loss: 0.06723713896730367
Generator loss:     9.634733873255113
--- EPOCH 89/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.27it/s]
Discriminator loss: 0.07592289916732732
Generator loss:     10.159999286427217
--- EPOCH 90/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.14it/s]
Discriminator loss: 0.034692903268424904
Generator loss:     7.952407612520106
--- EPOCH 91/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.06it/s]
Discriminator loss: 0.09340258161811267
Generator loss:     11.900830437155332
--- EPOCH 92/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.16it/s]
Discriminator loss: 0.14304472275954835
Generator loss:     10.162169736974379
--- EPOCH 93/100 ---
100%|██████████| 17/17 [00:02<00:00,  5.67it/s]
Discriminator loss: 0.036614916123011536
Generator loss:     9.434652889476103
--- EPOCH 94/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.09it/s]
Discriminator loss: 0.10977952669867698
Generator loss:     7.391819813672234
--- EPOCH 95/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.18it/s]
Discriminator loss: 0.0950633420554154
Generator loss:     9.348790112663718
--- EPOCH 96/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.17it/s]
Discriminator loss: 0.03516111479086034
Generator loss:     9.174052434809068
--- EPOCH 97/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.01it/s]
Discriminator loss: 0.042099163922316885
Generator loss:     7.89229404225069
--- EPOCH 98/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.03it/s]
Discriminator loss: 0.026527422370717806
Generator loss:     8.674463945276598
--- EPOCH 99/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.10it/s]
Discriminator loss: 0.014413932557491696
Generator loss:     8.634821246652042
--- EPOCH 100/100 ---
100%|██████████| 17/17 [00:03<00:00,  6.03it/s]
Discriminator loss: 0.03635896649211645
Generator loss:     8.053174355450798
In [13]:
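# Plot images from best or last model
# Load the best checkpoint if one was saved; otherwise keep the last trained generator
if os.path.isfile(f'{checkpoint_file}.pt'):
    gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
# Sample 15 images without tracking gradients (inference only) and move them to the CPU for plotting
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))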
*** Images Generated from best model:

Questions

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [14]:
from cs236605.answers import display_answer
import hw3.answers

Question 1

Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?

In [15]:
display_answer(hw3.answers.part3_q1)

The gradients are maintained during the generator update (the forward pass and backpropagation through the generator), and discarded during the discriminator update, when the generator is only used to sample fake data.

There is no need to compute gradients for the generator when we update the discriminator. In that step the generator only samples fake data for training the discriminator; its own parameters are not updated, so tracking gradients through it would only waste memory and computation. Therefore we should not accumulate gradients for this operation.

On the other hand, when we update the generator, the generated samples are passed forward through the discriminator and the loss is backpropagated through the discriminator into the generator to update the generator's parameters. In this case maintaining the gradients is mandatory; without them the generator cannot be trained.
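A minimal PyTorch sketch of the two sampling modes, assuming toy `nn.Linear` models and a BCE-with-logits loss as hypothetical stand-ins (these are not the homework's `Generator`/`Discriminator` classes or its loss functions). In the discriminator step the fake batch is sampled under `torch.no_grad()` (calling `.detach()` on the samples has the same effect on the generator's gradients), so nothing flows back into the generator. In the generator step the graph is kept, so the loss can be backpropagated through the discriminator into the generator.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models standing in for the homework's generator/discriminator.
z_dim, x_dim, batch = 4, 8, 16
gen = nn.Sequential(nn.Linear(z_dim, x_dim), nn.Tanh())
dsc = nn.Linear(x_dim, 1)  # outputs a logit for "real"
bce = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
dsc_opt = torch.optim.SGD(dsc.parameters(), lr=0.1)

real = torch.randn(batch, x_dim)  # placeholder for a batch of real data

# --- Discriminator step: sample fakes WITHOUT gradients ---
dsc_opt.zero_grad()
with torch.no_grad():
    fake_d = gen(torch.randn(batch, z_dim))  # no graph is built through gen
dsc_loss = bce(dsc(real), torch.ones(batch, 1)) + bce(dsc(fake_d), torch.zeros(batch, 1))
dsc_loss.backward()  # only dsc's parameters receive gradients
dsc_opt.step()
gen_grads_after_dsc = [p.grad for p in gen.parameters()]  # still None

# --- Generator step: sample fakes WITH gradients ---
gen_opt.zero_grad()
fake_g = gen(torch.randn(batch, z_dim))  # graph kept so the loss can reach gen's weights
gen_loss = bce(dsc(fake_g), torch.ones(batch, 1))  # generator wants fakes classified as real
gen_loss.backward()  # gradients flow through dsc into gen (dsc's grads are cleared next dsc step)
gen_opt.step()  # only the generator's parameters are updated here
gen_grads_after_gen = [p.grad for p in gen.parameters()]
```

The homework's `gen.sample(n, with_grad=...)` flag presumably switches between these two modes, but that is an assumption about its implementation.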

Question 2

  1. When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?

  2. What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?

In [16]:
display_answer(hw3.answers.part3_q2)
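  1. No. The generator loss is measured against the current discriminator, so a low value only means the generator is fooling the discriminator at that point. This can happen simply because the discriminator is poor at distinguishing fake from real data, not because the generated images are good. A low generator loss on its own therefore does not show that the generator is good, and it is not a sufficient stopping criterion.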

  2. A decreasing generator loss means the generator is performing better, producing increasingly convincing fakes. If the discriminator loss stays constant in this situation, the discriminator is still able to distinguish real from fake data with the same accuracy even though it is fed better and better "fakes". If that constant loss is small, this indicates a well-trained discriminator.
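A small numeric illustration of why a low generator loss alone is not a stopping criterion, assuming the non-saturating BCE form of the generator loss, i.e. the mean of $-\log D(G(z))$. The discriminator outputs below are hypothetical probabilities of "real" assigned to the same batch of fakes: the same samples give a tiny generator loss against an easily fooled discriminator and a large one against a discriminator that spots them.

```python
import math

def generator_loss(d_fake_probs):
    """Non-saturating generator loss: mean of -log D(G(z)) over a batch."""
    return sum(-math.log(p) for p in d_fake_probs) / len(d_fake_probs)

# Hypothetical outputs D(G(z)) for the SAME batch of fake images.
weak_dsc = [0.95, 0.97, 0.99, 0.96]    # poorly trained discriminator, easily fooled
strong_dsc = [0.02, 0.05, 0.01, 0.03]  # discriminator that recognizes the fakes

loss_weak = generator_loss(weak_dsc)      # ≈ 0.033
loss_strong = generator_loss(strong_dsc)  # ≈ 3.75
```

The generator is identical in both cases; only the discriminator it is measured against changed.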

Question 3

Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?

In [17]:
display_answer(hw3.answers.part3_q3)

In my implementation, the images from the VAE look better than the images from the GAN. The GAN images are noisier and less smooth.

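I expected the GAN to give better results, so this may be an architecture issue. The difference in character between the two is likely also related to how each model is trained: the VAE is trained with a pixel-wise reconstruction loss, which favors smooth (often blurry) outputs, whereas the GAN has no pixel-level loss and learns only from the discriminator's feedback, so its samples can be sharper but also noisier, especially when adversarial training is unstable.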
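A toy illustration of why a pixel-wise reconstruction loss tends to produce smooth images, assuming the VAE's data term is a mean squared error (the 1-D "images" below are hypothetical). When two sharp outputs are equally plausible, the prediction that minimizes the expected MSE is their pixel-wise average, which is a blurred edge rather than either sharp one. A GAN's loss has no such averaging term, which is one common explanation for the qualitative difference between the two models.

```python
# Two equally plausible "sharp" 1-D images (hypothetical data): an edge at two positions.
a = [0.0, 0.0, 1.0, 1.0]
b = [0.0, 1.0, 1.0, 1.0]

def mse(pred, target):
    """Pixel-wise mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def expected_mse(pred):
    """Expected loss when the target is a or b with equal probability."""
    return 0.5 * mse(pred, a) + 0.5 * mse(pred, b)

blurry = [(x + y) / 2 for x, y in zip(a, b)]  # pixel-wise average: [0.0, 0.5, 1.0, 1.0]

# The blurred average scores better under MSE than either sharp image.
scores = {"a": expected_mse(a), "b": expected_mse(b), "blurry": expected_mse(blurry)}
```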